{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "oss_mnist_reconstruction.ipynb",
      "provenance": [
        {
          "file_id": "1885p5OoBI1cbN8k57G8gCeXNqDN8Xpu9",
          "timestamp": 1659992477770
        },
        {
          "file_id": "11bLgt7q-TRwo-W65n1KiC55oqsgw9tw1",
          "timestamp": 1659712832740
        }
      ],
      "last_runtime": {
        "build_target": "//learning/deepmind/public/tools/ml_python:ml_notebook",
        "kind": "private"
      },
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    }
  },
  "cells": [
    {
      "cell_type": "markdown",
      "source": [
        "Copyright 2022 Authors. SPDX-License-Identifier: Apache-2.0"
      ],
      "metadata": {
        "id": "wtxoATIFv-QS"
      }
    },
    {
      "cell_type": "code",
      "source": [
        "#@title LICENSE\n",
        " \n",
        "# Licensed under the Apache License, Version 2.0 (the \"License\");\n",
        "# you may not use this file except in compliance with the License.\n",
        "# You may obtain a copy of the License at\n",
        "#\n",
        "# https://www.apache.org/licenses/LICENSE-2.0\n",
        "#\n",
        "# Unless required by applicable law or agreed to in writing, software\n",
        "# distributed under the License is distributed on an \"AS IS\" BASIS,\n",
        "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n",
        "# See the License for the specific language governing permissions and\n",
        "# limitations under the License."
      ],
      "metadata": {
        "cellView": "form",
        "id": "V9E-UpJKv8Nx"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bP8O1Gw-pL_4"
      },
      "source": [
        "import jax\n",
        "import functools\n",
        "import jax.numpy as jnp\n",
        "import flax.linen as nn\n",
        "import math\n",
        "\n",
        "from absl import logging\n",
        "from flax import linen as nn\n",
        "import gin\n",
        "import jax\n",
        "import jax.numpy as jnp\n",
        "import collections\n",
        "import tensorflow as tf \n",
        "import os.path as osp\n",
        "import pickle\n",
        "import flax\n",
        "\n",
        "import typing_extensions\n",
        "from typing import Any, Tuple, Optional\n",
        "from acme import types\n",
        "import dataclasses\n",
        "import optax\n",
        "\n",
        "import numpy as onp\n"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "import tensorflow_datasets as tfds\n",
        "import numpy as np\n",
        "\n",
        "from clu import checkpoint\n",
        "\n",
        "import matplotlib.pyplot as plt\n",
        "%matplotlib inline"
      ],
      "metadata": {
        "id": "AzPoEMymowru"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def get_mnist_data(num_samples: int = 60000):\n",
        "  \"\"\"Returns MNIST data as a matrix.\"\"\"\n",
        "  ds = tfds.load('mnist:3.*.*', split='train')\n",
        "  ds = ds.batch(num_samples)\n",
        "  data = next(ds.as_numpy_iterator())\n",
        "  X = np.reshape(data['image'], (num_samples, -1)) / 255.\n",
        "  # Returns a matrix of size `784 x num_samples`\n",
        "  mnist_matrix = X.T\n",
        "  mnist_matrix -= np.mean(mnist_matrix, axis=1, keepdims=True)\n",
        "  return mnist_matrix\n",
        "\n",
        "def get_mean_training_image():\n",
        "  ds = tfds.load('mnist:3.*.*', split='train')\n",
        "  ds = ds.batch(num_samples)\n",
        "  data = next(ds.as_numpy_iterator())\n",
        "  X = np.reshape(data['image'], (num_samples, -1)) / 255.\n",
        "  # Returns a matrix of size `784 x num_samples`\n",
        "  mnist_matrix = X.T\n",
        "  return np.mean(mnist_matrix, axis=1, keepdims=True)\n",
        "\n",
        "@jax.jit\n",
        "def f(X):\n",
        "  left_svd, sigma, _ = jnp.linalg.svd(X, full_matrices=False)\n",
        "  return left_svd, sigma"
      ],
      "metadata": {
        "id": "XAIBrqIGleL9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "num_samples = 10000\n",
        "ds = tfds.load('mnist:3.*.*', split='test')\n",
        "ds = ds.batch(num_samples)\n",
        "data = next(ds.as_numpy_iterator())\n",
        "labels = data['label']\n",
        "X = np.reshape(data['image'], (num_samples, -1)) / 255.\n",
        "# # Returns a matrix of size `784 x num_samples`\n",
        "test_matrix = X.T\n",
        "mean_training_image = get_mean_training_image()\n",
        "test_matrix -= mean_training_image"
      ],
      "metadata": {
        "id": "sE4CfBc1nfqe"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def plot_image(image, ax=None):\n",
        "  img = mean_training_image[:, 0] + image\n",
        "  img = img.reshape(28, -1)\n",
        "  if ax is None:\n",
        "    ax = plt.imshow(img)\n",
        "  else:\n",
        "    ax.imshow(img)\n",
        "  ax.axes.xaxis.set_ticks([])\n",
        "  ax.axes.yaxis.set_ticks([])\n",
        "  return ax"
      ],
      "metadata": {
        "id": "Oi0I6uWIox3J"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "ax = plot_image(test_matrix[:, 0])"
      ],
      "metadata": {
        "id": "6E497GKMtto2"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Load the true subspace\n",
        "\n",
        "## Add loading code.\n",
        "\n",
        "true_subspace_d = true_subspace[:, :16]"
      ],
      "metadata": {
        "id": "7-0St-kWpzeJ",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "true_subspace_d.shape"
      ],
      "metadata": {
        "id": "Wx1bHRvEp-AI"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "#@title Generate test images\n",
        "\n",
        "label_to_img = collections.defaultdict(list)\n",
        "for i, label in enumerate(labels):\n",
        "  label_to_img[label].append(test_matrix[:, i])\n",
        "\n",
        "for label, v in label_to_img.items():\n",
        "  label_to_img[label] = np.array(v)\n",
        "\n",
        "num_images_to_keep = 100\n",
        "images_to_test = []\n",
        "for label in sorted(label_to_img):\n",
        "  v = label_to_img[label]\n",
        "  images_to_test.append(v[:num_images_to_keep])\n",
        "\n",
        "images_to_test = np.concatenate(images_to_test, axis=0)\n",
        "images_to_test = images_to_test.T"
      ],
      "metadata": {
        "id": "X3TG8zhYushD",
        "cellView": "form"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "@jax.jit\n",
        "def solve_reconstruction(subspace, images):\n",
        "  x, residuals, _, _ = jnp.linalg.lstsq(subspace, images, rcond=None)\n",
        "  reconstructions = subspace @ x\n",
        "  return residuals, reconstructions\n",
        "\n",
        "def generate_labels_to_vals(residuals):\n",
        "  vals = np.argsort(residuals).to_py()\n",
        "  labels_to_vals = collections.defaultdict(list)\n",
        "  indices = [val // num_images_to_keep for val in vals]\n",
        "  for idx, val in zip(indices, vals):\n",
        "    labels_to_vals[idx].append(val)\n",
        "  return labels_to_vals"
      ],
      "metadata": {
        "id": "yEyHgjo7xRvH"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "residuals, reconstructions = solve_reconstruction(true_subspace_d, images_to_test)\n",
        "true_labels_to_vals = generate_labels_to_vals(residuals)\n",
        "print('Residuals:', jnp.mean(residuals))"
      ],
      "metadata": {
        "id": "XdBqJfDfqJ8T"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "fig, axes = plt.subplots(3, 10, figsize=(10, 3.0))\n",
        "\n",
        "titles = ['Ground \\nTruth', 'True \\n   Subspace', 'Lissa \\n(320 pixels)']\n",
        "\n",
        "for i in range(3):\n",
        "  if i == 0:\n",
        "    matrix_to_use = images_to_test\n",
        "\n",
        "  elif i == 1:\n",
        "    matrix_to_use = reconstructions\n",
        "  else:\n",
        "    matrix_to_use = lissa_reconstructions\n",
        "  for label in range(10):\n",
        "    img = matrix_to_use[:, true_labels_to_vals[label][1]]\n",
        "    ax = axes[i][label]\n",
        "    plot_image(img, ax=ax)\n",
        "    if label == 0:\n",
        "      ax.set_ylabel(titles[i], size=13) #  rotation=0,  labelpad=46)\n",
        "      # ax.set_xticklabels([])\n",
        "      # ax.set_yticklabels([])\n",
        "\n",
        "fig.subplots_adjust(hspace = 0.0, wspace=0.0, bottom=0, top=0.01)\n",
        "fig.tight_layout()"
      ],
      "metadata": {
        "id": "aKJzOz6h0c8f"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}